Predicting the coffee stock closing price with Facebook Prophet
%pip install prophet
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.offline as pyo
pyo.init_notebook_mode()
from prophet import Prophet
%matplotlib inline
import warnings
warnings.filterwarnings("ignore")
Defaulting to user installation because normal site-packages is not writeableNote: you may need to restart the kernel to use updated packages. Requirement already satisfied: prophet in c:\users\gentb\appdata\roaming\python\python39\site-packages (1.1.1) Requirement already satisfied: numpy>=1.15.4 in c:\users\gentb\appdata\roaming\python\python39\site-packages (from prophet) (1.21.2) Requirement already satisfied: matplotlib>=2.0.0 in c:\programdata\anaconda3\lib\site-packages (from prophet) (3.5.1) Requirement already satisfied: python-dateutil>=2.8.0 in c:\programdata\anaconda3\lib\site-packages (from prophet) (2.8.2) Requirement already satisfied: holidays>=0.14.2 in c:\users\gentb\appdata\roaming\python\python39\site-packages (from prophet) (0.18) Requirement already satisfied: pandas>=1.0.4 in c:\users\gentb\appdata\roaming\python\python39\site-packages (from prophet) (1.1.5) Requirement already satisfied: setuptools>=42 in c:\programdata\anaconda3\lib\site-packages (from prophet) (61.2.0) Requirement already satisfied: wheel>=0.37.0 in c:\programdata\anaconda3\lib\site-packages (from prophet) (0.37.1) Requirement already satisfied: tqdm>=4.36.1 in c:\programdata\anaconda3\lib\site-packages (from prophet) (4.64.0) Requirement already satisfied: convertdate>=2.1.2 in c:\users\gentb\appdata\roaming\python\python39\site-packages (from prophet) (2.4.0) Requirement already satisfied: cmdstanpy>=1.0.4 in c:\users\gentb\appdata\roaming\python\python39\site-packages (from prophet) (1.0.8) Requirement already satisfied: LunarCalendar>=0.0.9 in c:\users\gentb\appdata\roaming\python\python39\site-packages (from prophet) (0.0.9) Requirement already satisfied: setuptools-git>=1.2 in c:\users\gentb\appdata\roaming\python\python39\site-packages (from prophet) (1.2) Requirement already satisfied: pymeeus<=1,>=0.3.13 in c:\users\gentb\appdata\roaming\python\python39\site-packages (from convertdate>=2.1.2->prophet) (0.5.12) Requirement already satisfied: hijri-converter in 
c:\users\gentb\appdata\roaming\python\python39\site-packages (from holidays>=0.14.2->prophet) (2.2.4) Requirement already satisfied: korean-lunar-calendar in c:\users\gentb\appdata\roaming\python\python39\site-packages (from holidays>=0.14.2->prophet) (0.3.1) Requirement already satisfied: ephem>=3.7.5.3 in c:\users\gentb\appdata\roaming\python\python39\site-packages (from LunarCalendar>=0.0.9->prophet) (4.1.4) Requirement already satisfied: pytz in c:\programdata\anaconda3\lib\site-packages (from LunarCalendar>=0.0.9->prophet) (2021.3) Requirement already satisfied: kiwisolver>=1.0.1 in c:\programdata\anaconda3\lib\site-packages (from matplotlib>=2.0.0->prophet) (1.3.2) Requirement already satisfied: pyparsing>=2.2.1 in c:\programdata\anaconda3\lib\site-packages (from matplotlib>=2.0.0->prophet) (3.0.4) Requirement already satisfied: packaging>=20.0 in c:\programdata\anaconda3\lib\site-packages (from matplotlib>=2.0.0->prophet) (21.3) Requirement already satisfied: fonttools>=4.22.0 in c:\programdata\anaconda3\lib\site-packages (from matplotlib>=2.0.0->prophet) (4.25.0) Requirement already satisfied: pillow>=6.2.0 in c:\programdata\anaconda3\lib\site-packages (from matplotlib>=2.0.0->prophet) (9.0.1) Requirement already satisfied: cycler>=0.10 in c:\programdata\anaconda3\lib\site-packages (from matplotlib>=2.0.0->prophet) (0.11.0) Requirement already satisfied: six>=1.5 in c:\programdata\anaconda3\lib\site-packages (from python-dateutil>=2.8.0->prophet) (1.16.0) Requirement already satisfied: colorama in c:\programdata\anaconda3\lib\site-packages (from tqdm>=4.36.1->prophet) (0.4.4)
# Load the coffee price history and parse the 'Date' column into datetimes.
# The path is a raw string: in a normal literal "\c" is an unrecognized escape
# sequence, which Python flags with a warning (and plans to reject eventually).
# The raw literal has exactly the same runtime value.
df = pd.read_csv(r"E:\coffee.csv", parse_dates=['Date'])
df.head()
| Date | Open | High | Low | Close | Volume | Currency | |
|---|---|---|---|---|---|---|---|
| 0 | 2000-01-03 | 122.25 | 124.00 | 116.10 | 116.50 | 6640 | USD |
| 1 | 2000-01-04 | 116.25 | 120.50 | 115.75 | 116.25 | 5492 | USD |
| 2 | 2000-01-05 | 115.00 | 121.00 | 115.00 | 118.60 | 6165 | USD |
| 3 | 2000-01-06 | 119.00 | 121.40 | 116.50 | 116.85 | 5094 | USD |
| 4 | 2000-01-07 | 117.25 | 117.75 | 113.80 | 114.15 | 6855 | USD |
# Print each column's dtype and non-null count (checks for gaps and that 'Date' parsed as datetime).
df.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 5746 entries, 0 to 5745 Data columns (total 7 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 Date 5746 non-null datetime64[ns] 1 Open 5746 non-null float64 2 High 5746 non-null float64 3 Low 5746 non-null float64 4 Close 5746 non-null float64 5 Volume 5746 non-null int64 6 Currency 5746 non-null object dtypes: datetime64[ns](1), float64(4), int64(1), object(1) memory usage: 314.4+ KB
# The forecasting target is 'Close'.
# 'Currency' is not used by anything below, so remove it from df in place.
df.drop(columns=['Currency'], inplace=True)
# Summary statistics for the remaining numeric columns.
df.describe()
| Open | High | Low | Close | Volume | |
|---|---|---|---|---|---|
| count | 5746.000000 | 5746.000000 | 5746.000000 | 5746.000000 | 5746.000000 |
| mean | 127.267635 | 128.847034 | 125.784669 | 127.215567 | 8807.178907 |
| std | 50.569425 | 51.164948 | 49.851487 | 50.506519 | 9612.789034 |
| min | 41.500000 | 42.000000 | 41.500000 | 41.500000 | 0.000000 |
| 25% | 98.800000 | 100.112500 | 97.862500 | 98.650000 | 61.000000 |
| 50% | 120.400000 | 121.600000 | 118.950000 | 120.250000 | 7008.000000 |
| 75% | 144.800000 | 146.000000 | 143.000000 | 144.337500 | 14497.750000 |
| max | 305.300000 | 306.250000 | 304.000000 | 304.900000 | 62045.000000 |
# Average every remaining column per week ('W') and per month ('M'),
# bucketing rows by the 'Date' column.
# NOTE(review): after resample(on='Date') the dates live in the index, so the
# plots below rely on seaborn resolving x='Date' from the index name - confirm
# with the installed seaborn version.
sale_weekly = df.resample('W', on='Date').mean()
sale_monthly = df.resample('M', on='Date').mean()

# Three stacked panels: daily, weekly and monthly close price.
fig, (ax1, ax2, ax3) = plt.subplots(3, figsize=(12, 8))
_panels = (
    (ax1, df, 'Coffee Daily Close Price'),
    (ax2, sale_weekly, 'Coffee Weekly Close Price'),
    (ax3, sale_monthly, 'Coffee Monthly Close Price'),
)
for _axis, _frame, _title in _panels:
    sns.lineplot(x='Date', y='Close', data=_frame, ax=_axis)
    _axis.set_title(_title)
plt.tight_layout()
# Build the two-column frame handed to Prophet: dates as 'ds', close price as 'y'.
coffee_price = df[['Date', 'Close']].rename(columns={'Date': 'ds', 'Close': 'y'})

# Hold out the final 365 rows for testing and train on everything before them.
# Note this is 365 rows, not 365 calendar days (the data skips weekends).
train_set = coffee_price.iloc[:-365]
test_set = coffee_price.iloc[-365:]

# Train a Prophet model with yearly seasonality and US holidays on the training rows.
m = Prophet(yearly_seasonality=True)
m.add_country_holidays(country_name='US')
m.fit(train_set)
18:19:40 - cmdstanpy - INFO - Chain [1] start processing 18:19:46 - cmdstanpy - INFO - Chain [1] done processing
<prophet.forecaster.Prophet at 0x1bb86685370>
# Run the fitted model over the held-out rows.
prophet_pred = m.predict(test_set)

# Draw Prophet's forecast figure, then overlay the actual held-out close
# prices in green on the same (current) axes for visual comparison.
pred_plot = m.plot(prophet_pred)
sns.scatterplot(
    data=test_set,
    x='ds',
    y='y',
    color='g',
    label='True Daily close price',
)
<AxesSubplot:xlabel='ds', ylabel='y'>
# Preview the first rows of the test-set forecast (yhat, its bounds, and the per-component columns).
prophet_pred.head()
| ds | trend | yhat_lower | yhat_upper | trend_lower | trend_upper | Christmas Day | Christmas Day_lower | Christmas Day_upper | Christmas Day (Observed) | ... | weekly | weekly_lower | weekly_upper | yearly | yearly_lower | yearly_upper | multiplicative_terms | multiplicative_terms_lower | multiplicative_terms_upper | yhat | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2021-03-25 | 101.894239 | 87.358103 | 119.462934 | 101.894239 | 101.894239 | 0.0 | 0.0 | 0.0 | 0.0 | ... | -0.884969 | -0.884969 | -0.884969 | 2.089961 | 2.089961 | 2.089961 | 0.0 | 0.0 | 0.0 | 103.099231 |
| 1 | 2021-03-26 | 101.874364 | 87.590018 | 119.470158 | 101.874364 | 101.874364 | 0.0 | 0.0 | 0.0 | 0.0 | ... | -1.077511 | -1.077511 | -1.077511 | 1.926802 | 1.926802 | 1.926802 | 0.0 | 0.0 | 0.0 | 102.723655 |
| 2 | 2021-03-29 | 101.814739 | 87.273833 | 116.874786 | 101.814739 | 101.814739 | 0.0 | 0.0 | 0.0 | 0.0 | ... | -1.095343 | -1.095343 | -1.095343 | 1.541970 | 1.541970 | 1.541970 | 0.0 | 0.0 | 0.0 | 102.261366 |
| 3 | 2021-03-30 | 101.794863 | 86.100350 | 118.323247 | 101.794863 | 101.794863 | 0.0 | 0.0 | 0.0 | 0.0 | ... | -0.865276 | -0.865276 | -0.865276 | 1.453662 | 1.453662 | 1.453662 | 0.0 | 0.0 | 0.0 | 102.383249 |
| 4 | 2021-03-31 | 101.774988 | 87.045318 | 117.011452 | 101.774988 | 101.774988 | 0.0 | 0.0 | 0.0 | 0.0 | ... | -0.829819 | -0.829819 | -0.829819 | 1.386737 | 1.386737 | 1.386737 | 0.0 | 0.0 | 0.0 | 102.331906 |
5 rows × 70 columns
# Keep only the columns charted for the test-set forecast and expose the
# 'yearly' component under the friendlier name 'season'.
prophet_plot = (
    prophet_pred
    .loc[:, ['ds', 'trend', 'yearly', 'holidays', 'yhat']]
    .rename(columns={'yearly': 'season'})
)
prophet_plot.head()
| ds | trend | season | holidays | yhat | |
|---|---|---|---|---|---|
| 0 | 2021-03-25 | 101.894239 | 2.089961 | 0.0 | 103.099231 |
| 1 | 2021-03-26 | 101.874364 | 1.926802 | 0.0 | 102.723655 |
| 2 | 2021-03-29 | 101.814739 | 1.541970 | 0.0 | 102.261366 |
| 3 | 2021-03-30 | 101.794863 | 1.453662 | 0.0 | 102.383249 |
| 4 | 2021-03-31 | 101.774988 | 1.386737 | 0.0 | 102.331906 |
# Chart the trend, seasonal and holiday components of the test-set forecast,
# one stacked plotly panel per component, in date order.
prophet_plot = prophet_plot.sort_values(by='ds')
fig = make_subplots(rows=3, cols=1, subplot_titles=('trend','season','holidays'))
for _row, _component in enumerate(('trend', 'season', 'holidays'), start=1):
    fig.add_trace(
        go.Scatter(x=prophet_plot['ds'], y=prophet_plot[_component]),
        row=_row,
        col=1,
    )
# Render explicitly instead of relying on the notebook echoing the value of the
# last add_trace() call; this matches the full-dataset component chart, which
# also ends with fig.show(), and keeps the figure visible outside a notebook cell.
fig.show()
# Fit a fresh Prophet model (yearly seasonality + US holidays) on the full dataset.
# NOTE(review): this rebinds `m`, replacing the model that was fit on train_set;
# any later call that pairs `m` with prophet_pred will use this full-data model instead.
m = Prophet(yearly_seasonality=True)
m.add_country_holidays(country_name='US')
m.fit(coffee_price)
18:19:47 - cmdstanpy - INFO - Chain [1] start processing 18:19:53 - cmdstanpy - INFO - Chain [1] done processing
<prophet.forecaster.Prophet at 0x1bb8b1c4bb0>
# Build the prediction frame: the historical dates plus 365 additional periods
# (frequency left at the library default), then forecast over all of it.
df_pred = m.make_future_dataframe(periods=365, include_history=True)
coffee_pred = m.predict(df_pred)

# Keep only the components charted below; 'yearly' is shown as 'season'.
coffee_plot = (
    coffee_pred
    .loc[:, ['ds', 'trend', 'yearly', 'holidays']]
    .rename(columns={'yearly': 'season'})
)
coffee_plot.head()
| ds | trend | season | holidays | |
|---|---|---|---|---|
| 0 | 2000-01-03 | 106.606151 | 3.499964 | 0.0 |
| 1 | 2000-01-04 | 106.509236 | 3.641218 | 0.0 |
| 2 | 2000-01-05 | 106.412322 | 3.763193 | 0.0 |
| 3 | 2000-01-06 | 106.315408 | 3.865800 | 0.0 |
| 4 | 2000-01-07 | 106.218494 | 3.949306 | 0.0 |
# Chart the trend, seasonal and holiday components of the full forecast,
# one stacked plotly panel per component, in date order.
coffee_plot = coffee_plot.sort_values(by='ds')
fig = make_subplots(rows=3, cols=1, subplot_titles=('trend','season','holidays'))
for _row, _component in enumerate(('trend', 'season', 'holidays'), start=1):
    fig.add_trace(
        go.Scatter(x=coffee_plot['ds'], y=coffee_plot[_component]),
        row=_row,
        col=1,
    )
fig.show()
Component plots for the test-set forecast:
# Component panels for the test-set forecast.
# NOTE(review): by this point `m` has been re-created and refit on the full dataset,
# so it is not the model that produced prophet_pred - presumably any panel Prophet
# derives from the model object reflects the full-data fit; confirm this is intended.
m.plot_components(prophet_pred)
Component plots for the forecast over the original (full) dataset:
# Component panels for the forecast over the full history plus the 365 future periods.
m.plot_components(coffee_pred)
[1] https://www.theguardian.com/business/2011/apr/21/commodities-coffee-shortage-price-rise-expected